import torch
import torch.optim as optim
import torch.nn as nn
import argparse
import os

from model import WideResNet
from data import get_cifar10_loaders, get_cifar100_loaders
from IAM import IAM_S

def evaluate(model, loader=None, device=None):
    """Return the top-1 accuracy of ``model`` in percent.

    Args:
        model: Module invoked as ``model(images)``; the predicted class is the
            arg-max of its output along dimension 1.
        loader: Iterable yielding ``(images, labels)`` batches. When omitted,
            the module-level ``test_loader`` created by the training script is
            used, which keeps ``evaluate(model)`` working as before.
        device: Device each batch is moved to before the forward pass. When
            omitted, the device of the model's first parameter is used (CPU if
            the model has no parameters).

    Returns:
        Accuracy as a float in ``[0, 100]``.

    Raises:
        ValueError: If the loader yields no samples, so accuracy is undefined.

    Note:
        The model is switched to eval mode and left in that mode; callers that
        continue training must call ``model.train()`` again.
    """
    if loader is None:
        # Fall back to the loader the script defines at module level.
        loader = test_loader
    if device is None:
        # The batches must live where the weights live, so follow the model.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("evaluate() received a loader with no samples")
    return 100. * correct / total

if __name__ == "__main__":
    # Script entry point: parse the CLI, build data/model/optimizer, train for
    # `epochs` epochs, evaluate on the test loader after every epoch, and
    # finally report the lowest test error seen.
    #
    # The names below are intentionally left at module level (not wrapped in a
    # function) because `evaluate` reads `test_loader` and `device` as globals.

    # Hard-coded temp directory for this machine; overrides any existing TMPDIR.
    os.environ["TMPDIR"] = "/home/intern/tmp"

    # Supported datasets: loader factory and number of classes.
    dataset_factories = {
        "CIFAR-10": (get_cifar10_loaders, 10),
        "CIFAR-100": (get_cifar100_loaders, 100),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("--noise_scale", default=0.05, type=float)
    parser.add_argument("--dropout", default=0.0, type=float)
    parser.add_argument("--rho", default=0.1, type=float)
    parser.add_argument("--epochs", default=200, type=int)
    parser.add_argument("--lr", default=0.1, type=float)
    # `choices` rejects unknown names up front instead of failing later with a
    # NameError on `num_labels`.
    parser.add_argument(
        "--dataset",
        default="CIFAR-10",
        type=str,
        choices=sorted(dataset_factories),
    )
    args = parser.parse_args()

    epochs = args.epochs
    # Use CUDA when available, otherwise fall back to CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    get_loaders, num_labels = dataset_factories[args.dataset]
    train_loader, test_loader = get_loaders()

    model = WideResNet(
        depth=16,
        width_factor=8,
        dropout=args.dropout,
        in_channels=3,
        labels=num_labels,
    )
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # IAM_S is project-local; it is constructed with the parameters, a base
    # optimizer class (SGD) and the hyper-parameters forwarded below.
    optimizer = IAM_S(
        model.parameters(),
        torch.optim.SGD,
        rho=args.rho,
        noise_scale=args.noise_scale,
        lr=args.lr,
        momentum=0.9,
        weight_decay=5e-4,
    )
    # LR is multiplied by 0.2 at ~30%, ~60% and ~80% of training
    # (60/120/160 for the default 200 epochs).
    # NOTE(review): integer division makes all three milestones 0 when
    # epochs < 10 -- confirm that short runs are not expected to decay the LR.
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[epochs // 10 * 3, epochs // 10 * 6, epochs // 10 * 8],
        gamma=0.2,
    )

    loss_history = []
    error_history = []
    loss_S_history = []

    # Training loop
    for epoch in range(epochs):
        model.train()

        running_loss_orig_accum = 0.0
        running_loss_s_accum = 0.0

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            # Closure handed to IAM_S.step: forward pass, loss and backward
            # pass on the current batch. It is invoked within this iteration,
            # so capturing the loop variables is safe.
            def closure():
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                # Only report the NaN; the backward pass still runs.
                if torch.isnan(loss):
                    print(f"NaN detected in main loss calculation inside closure at epoch {epoch}, batch {batch_idx}.")

                loss.backward()
                return loss, outputs

            # One IAM_S step; it returns two losses, unpacked here as the
            # original loss and `loss_s` (either may be None, see checks below).
            loss_original, loss_s = optimizer.step(
                closure_main_loss=closure,
                model_nn_module=model,
                inputs_for_model=inputs,
            )

            # Accumulate only finite, available losses for the epoch averages.
            if loss_original is not None and not torch.isnan(loss_original):
                running_loss_orig_accum += loss_original.item()
            if loss_s is not None and not torch.isnan(loss_s):
                running_loss_s_accum += loss_s.item()
            elif loss_s is None:
                print(f"Epoch {epoch}, Batch {batch_idx}: Perturbed step was skipped.")

        scheduler.step()

        # Per-epoch averages (skipped/NaN batches contribute 0 to the sum).
        avg_loss = running_loss_orig_accum / len(train_loader)
        avg_loss_s = running_loss_s_accum / len(train_loader)
        acc = evaluate(model)
        error = 100 - acc
        error_history.append(error)
        print(f"Epoch: {epoch+1}, test error: {error:.2f}, loss : {avg_loss:.4f}, loss_S : {avg_loss_s:.4f}", flush=True)
        loss_history.append(avg_loss)
        loss_S_history.append(avg_loss_s)

    # Kept inside the guard so importing this module does not touch
    # `error_history`; skipped when no epoch ran (min() of an empty list).
    if error_history:
        print("best error: ", min(error_history))
